#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
sim_runner.py (fixed)
---------------------
Discrete-event simulation for corridor service with optional off-corridor excursions.

Bug fixes and improvements relative to the previous revision (sim_runner (4).py):
- Implements return-to-corridor window = min{N_max corridor links, T_max seconds}
  as a single *time* deadline. We compute T_links by summing nominal T0 over the
  next N corridor links from the parent stop, then set deadline = t0 + min(T_max, T_links).
- Feasibility checks now add the expected side dwell (mean 'alpha') to the planned
  excursion time. Realized dwell and travel times are stochastic, so borderline
  excursions that pass the check can still miss the deadline.
- Dwell parameter keys are tolerant: supports alpha/alpha_s, beta_b/beta_board[/_s],
  beta_a/beta_alight[/_s], sigma/noise_sigma[/_s].
- Stops loader recognizes 'x_along_corridor_m' in addition to 'x_along_corridor'.
- Seed manifest tolerant to an alternate {'seeds': {'requests','links','departures'}} format.
- Keeps deterministic vehicle–departure mapping, and the candidate pool restriction
  (only operating vehicles are considered for assignments).

Outputs per run: events.csv, decisions.csv, kpis.csv, run_meta.json
"""

import argparse
import json
import math
import os
import sys
import uuid
import hashlib
import datetime as dt
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional, Any
import heapq
import time
import zipfile
import csv

import numpy as np
import pandas as pd
import yaml

# Optional progress bars
try:
    from tqdm import tqdm
    _HAS_TQDM = True
except Exception:
    _HAS_TQDM = False

# ------------- Utilities -------------

def log_print(*args, **kwargs):
    """Print *args* prefixed with a ``[HH:MM:SS]`` local wall-clock stamp.

    Keyword arguments (``sep``, ``end``, ``file``, ``flush``) are passed
    straight through to :func:`print`.
    """
    stamp = f"[{dt.datetime.now():%H:%M:%S}]"
    print(stamp, *args, **kwargs)

def sha256_of_text(txt: str) -> str:
    """Return the hexadecimal SHA-256 digest of *txt* encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(txt.encode("utf-8"))
    return hasher.hexdigest()

def now_iso() -> str:
    """Return the current local (naive) time as ISO-8601, truncated to seconds."""
    # With microseconds zeroed, isoformat() omits the fractional part, which is
    # the same text as isoformat(timespec="seconds").
    current = dt.datetime.now().replace(microsecond=0)
    return current.isoformat()

def ensure_dir(p: str):
    """Create directory *p*, including missing parents; no error if it exists."""
    os.makedirs(name=p, exist_ok=True)

def read_csv(path: str) -> pd.DataFrame:
    """Load the CSV file at *path* using pandas' default parsing options."""
    frame = pd.read_csv(path)
    return frame

def read_yaml(path: str) -> dict:
    """Parse the YAML file at *path* with ``yaml.safe_load``.

    The file is decoded as UTF-8 explicitly instead of with the platform's
    locale encoding, so scenario files read the same on every OS. Returns
    whatever ``safe_load`` produces (it may be ``None`` for an empty file;
    the simulator's callers guard with ``or {}``).
    """
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

def read_json(path: str) -> dict:
    """Parse the JSON file at *path*.

    JSON text is UTF-8 by specification (RFC 8259), so the file is decoded
    as UTF-8 explicitly rather than with the locale default encoding.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def write_json(obj: dict, path: str):
    """Write *obj* to *path* as two-space-indented JSON (UTF-8).

    The object is serialized before the file is opened, so a
    non-serializable object raises ``TypeError`` without truncating or
    half-writing an existing file at *path*.
    """
    text = json.dumps(obj, indent=2)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)

# ------------- Data Classes -------------

@dataclass
class Stop:
    """One stop record loaded from stops.csv by ``Simulator._load_stops``."""
    stop_id: str
    x: float  # position along the corridor (units presumably metres); loader falls back to stop_order/row index
    line_id: str
    in_overlap: bool  # compute_feasibility counts an overlap parent stop as on the vehicle's corridor
    stop_type: str  # 'corridor' | 'side'
    parent_corr_stop_id: Optional[str]  # corridor stop that excursions to this side stop leave from; None when blank
    stop_order: int  # sequence index on the line; used to pick the next corridor link and the travel direction

@dataclass
class ExcursionArc:
    """Off-corridor arc from a corridor stop to a side stop (excursion_arcs.csv)."""
    from_corr_stop_id: str
    to_side_stop_id: str
    distance_m: float
    turn_penalty_s: float  # fixed seconds added to each traversal in the feasibility check
    T0: float  # nominal traversal time; the loader uses distance / 10.0 when no T0 column value is given
    sigma: float  # lognormal std on multiplicative factor (sigma of lognormal(0, sigma) in draw_travel_time)

@dataclass
class Link:
    """Directed link between two stops, loaded from links.csv."""
    from_stop: str
    to_stop: str
    T0: float  # nominal travel time; draw_travel_time samples T0 * lognormal(0, sigma)
    sigma: float  # sigma of the lognormal multiplicative travel-time factor
    control_point_flag: int  # integer flag from the CSV (0 when missing); its use is not visible in this chunk
    link_type: str  # 'corridor' | 'excursion'

@dataclass
class Departure:
    """Scheduled trip start for a line (departures.csv)."""
    dep_id: str
    line_id: str
    t_depart: float  # simulation time in seconds (the loader rescales minute-valued columns)

@dataclass
class Request:
    """Passenger trip request loaded from requests.csv."""
    req_id: str
    t_request: float  # simulation time in seconds (the loader rescales minute-valued columns)
    o_stop: str
    d_stop: str
    group_size: int  # passengers travelling together; 1 when not given
    patience: float  # presumably the maximum wait (s) before abandoning; 9999.0 when not given
    o_type: str  # origin type as given in the CSV; 'corridor' when not given
    d_type: str  # destination type as given in the CSV; 'corridor' when not given

@dataclass
class Vehicle:
    """Static attributes plus mutable simulation state of one vehicle."""
    vehicle_id: str
    line_id: str
    capacity: int
    # dynamic state
    onboard: int = 0
    loc_stop: Optional[str] = None  # current stop id; None until init_events places the vehicle
    next_event_time: float = 0.0
    status: str = "idle"  # 'idle','reserved','enroute','dwell','excursion'
    current_link: Optional[Tuple[str, str]] = None  # presumably (from_stop, to_stop) while travelling
    headway_clock: float = 0.0  # lateness measure read by Simulator.headway_risk_score
    excursion_budget_s: float = 0.0  # max planned seconds for one excursion (compute_feasibility)
    return_window_links: int = 0  # N_max corridor links in the return window; <= 0 means unused
    return_window_time_s: float = 0.0  # T_max seconds of the return window; <= 0 means unused
    excursion_quota_left: Optional[int] = None  # remaining excursions; None means no quota (SlackAware skips the check)

@dataclass(order=True)
class Event:
    """Event-queue entry ordered by (time, priority, typ).

    Used with ``heapq``: the earliest ``time`` pops first, ties are broken by
    the smaller ``priority`` value and then by ``typ``. ``payload`` is
    excluded from comparisons.
    """
    time: float
    priority: int
    typ: str
    payload: Any=field(compare=False)

# ------------- Policies -------------

class PolicyBase:
    """Common interface for excursion-acceptance policies.

    Subclasses set ``name`` and implement :meth:`decide_accept`, returning an
    ``(accepted, reason)`` tuple for a request given precomputed feasibility.
    """

    name = "base"

    def __init__(self, policy_params: dict):
        # Keep the caller's dict itself; substitute a fresh {} for None/empty.
        self.params = policy_params if policy_params else {}

    def decide_accept(self, ctx: dict, vehicle: Vehicle, req: Request, feas: dict) -> Tuple[bool, str]:
        """Decide whether to accept *req* on *vehicle*; subclasses must override."""
        raise NotImplementedError

class BaselineStatusQuo(PolicyBase):
    """Status quo: serve corridor pickups only, subject to capacity."""

    name = "baseline"

    def decide_accept(self, ctx, vehicle, req, feas):
        # Checked in order; the first failing check determines the reason.
        gates = (
            ("is_corridor_pickup", "no_excursions"),
            ("within_capacity", "capacity"),
        )
        for flag, reason in gates:
            if not feas[flag]:
                return (False, reason)
        return (True, "corridor_ok")

class MyopicFeasible(PolicyBase):
    """Accept any request that passes capacity, budget and return-window checks."""

    name = "myopic"

    def decide_accept(self, ctx, vehicle, req, feas):
        # Checked in order; the first failing check determines the reason.
        gates = (
            ("within_capacity", "capacity"),
            ("budget_ok", "budget"),
            ("return_window_ok", "return_window"),
        )
        for flag, reason in gates:
            if not feas[flag]:
                return (False, reason)
        return (True, "feasible")

class SlackAware(PolicyBase):
    """Feasibility plus headway-risk and per-vehicle quota gating.

    Rejects when the headway risk from ``ctx["headway_risk_score"](vehicle)``
    exceeds ``headway_risk_threshold`` (default 0.6), or when the vehicle's
    excursion quota is set and exhausted.
    """

    name = "slackaware"

    def decide_accept(self, ctx, vehicle, req, feas):
        if not feas["within_capacity"]:
            return (False, "capacity")
        if not (feas["budget_ok"] and feas["return_window_ok"]):
            return (False, "budget_or_window")

        score = ctx["headway_risk_score"](vehicle)
        limit = float(self.params.get("headway_risk_threshold", 0.6))
        if score > limit:
            return (False, "headway_risk")

        remaining = vehicle.excursion_quota_left
        quota_exhausted = remaining is not None and remaining <= 0
        return (False, "quota_exhausted") if quota_exhausted else (True, "gated_ok")

# Registry keyed by each policy class's short ``name`` attribute
# ("baseline", "myopic", "slackaware").
POLICY_REGISTRY = {
    policy_cls.name: policy_cls
    for policy_cls in (BaselineStatusQuo, MyopicFeasible, SlackAware)
}

# ------------- Simulator -------------

class Simulator:
    def __init__(self, scen_dir: str, policy_name: str, replicate_id: int, out_dir: str, horizon_s: float = 10800.0):
        """Load a scenario, build lookup indexes and vehicles, and open the output logs.

        Args:
            scen_dir: Scenario directory. stops.csv, links.csv,
                excursion_arcs.csv (optional), departures.csv, requests.csv,
                vehicles.csv, dwell_params.yml, policy_params.yml and
                seed_manifest.json are read from it.
            policy_name: Key into ``POLICY_REGISTRY``; an unknown name raises
                ``KeyError``.
            replicate_id: Replicate index; selects the replicate seed and
                offsets the travel/dwell seeds.
            out_dir: Output directory (created if missing). events.csv and
                decisions.csv are opened here and their headers written.
            horizon_s: Simulation horizon; ``schedule_event`` drops events
                later than this.

        NOTE(review): the two CSV file handles stay open after __init__;
        closing them is presumably done elsewhere in the file — confirm.
        """
        self.scen_dir = scen_dir
        self.policy_name = policy_name
        self.replicate_id = replicate_id
        self.out_dir = out_dir
        self.horizon_s = float(horizon_s)

        # Load inputs
        self.stops = self._load_stops()
        self.links = self._load_links()
        self.excursion_arcs = self._load_excursions()
        self.departures = self._load_departures()
        self.requests = self._load_requests()
        self.vehicles_df = read_csv(os.path.join(scen_dir, "vehicles.csv"))
        # `or {}` guards against read_yaml returning None (e.g. an empty file).
        self.dwell_params = read_yaml(os.path.join(scen_dir, "dwell_params.yml")) or {}
        self.policy_params_raw = read_yaml(os.path.join(scen_dir, "policy_params.yml")) or {}
        self.policy_params = self._normalize_policy_params(self.policy_params_raw)
        self.seed_manifest = self._load_seed_manifest(os.path.join(scen_dir, "seed_manifest.json"))

        # RNG streams
        # A general per-replicate stream plus separate travel-time and dwell-time
        # streams (base seed + replicate_id); literal fallbacks apply when the
        # manifest lacks a key.
        self.rng = np.random.default_rng(self.seed_manifest.get("replicate_seeds", {}).get(str(replicate_id), 12345))
        self.rng_travel = np.random.default_rng(self.seed_manifest.get("travel_seed", 4242) + replicate_id)
        self.rng_dwell = np.random.default_rng(self.seed_manifest.get("dwell_seed", 31415) + replicate_id)

        # Indexes
        self.stop_by_id: Dict[str, Stop] = {s.stop_id: s for s in self.stops}
        self.links_from: Dict[str, List[Link]] = {}
        for lk in self.links:
            self.links_from.setdefault(lk.from_stop, []).append(lk)

        # Canonical corridor stop id by (line_id, stop_order)
        self.corridor_stop_by_line_order: Dict[Tuple[str,int], str] = {
            (s.line_id, s.stop_order): s.stop_id for s in self.stops if s.stop_type == "corridor"
        }

        # Map (from_corr_stop_id, to_side_stop_id) -> ExcursionArc
        # (a later row with the same pair overwrites an earlier one)
        self.excursion_by_pair: Dict[Tuple[str, str], ExcursionArc] = {}
        for ea in self.excursion_arcs:
            self.excursion_by_pair[(ea.from_corr_stop_id, ea.to_side_stop_id)] = ea

        # Policy
        PolicyClass = POLICY_REGISTRY[policy_name]
        self.policy = PolicyClass(self.policy_params.get(policy_name, {}))

        # Vehicles
        # Must run after policy_name/policy_params are set: _init_vehicles reads them
        # for per-vehicle budget and return-window defaults.
        self.vehicles: Dict[str, Vehicle] = {}
        self._init_vehicles()

        # Pending and active excursions
        # (presumably keyed by vehicle_id — used outside this chunk; confirm)
        self.pending_excursions: Dict[str, List[dict]] = {}
        self.active_excursion: Dict[str, Optional[dict]] = {}

        # Event queue
        # Heap of Event objects, maintained with heapq by schedule_event.
        self.t = 0.0
        self.evq: List[Event] = []

        # Streaming logs
        ensure_dir(self.out_dir)
        self._events_fp = open(os.path.join(self.out_dir, "events.csv"), "w", newline="")
        self._events_wr = csv.DictWriter(self._events_fp, fieldnames=[
            "time","type","vehicle_id","from_stop","to_stop","stop_id","tdwell","ttravel"
        ])
        self._events_wr.writeheader()

        self._dec_fp = open(os.path.join(self.out_dir, "decisions.csv"), "w", newline="")
        self._dec_wr = csv.DictWriter(self._dec_fp, fieldnames=[
            "time","req_id","vehicle_id","decision","reason","is_corridor_pickup","budget_ok","return_window_ok","within_capacity"
        ])
        self._dec_wr.writeheader()

        # Running stats for KPIs
        # depart_* and vehicle_depart_count are updated by log_event, accept/abandon
        # counts by log_decision; the rest are presumably updated elsewhere in the file.
        self._stats = {
            "depart_count": 0,
            "depart_ttravel_sum": 0.0,
            "depart_ttravel_sqsum": 0.0,
            "vehicle_depart_count": 0,
            "accept_count": 0,
            "abandon_count": 0,
            "wait_sum": 0.0,
            "wait_n": 0,
            "excursions_km": 0.0,
            "excursions_count": 0,
            "missed_return_window_count": 0,
        }

        # Request statuses
        self.req_status: Dict[str, dict] = {}

        # Run meta
        self.meta = {
            "scenario_dir": scen_dir,
            "policy": policy_name,
            "replicate_id": replicate_id,
            "simulator_version": "1.6.0-fixed",
            "input_schema_version": 1,
            "output_schema_version": 1,
            "started_at": now_iso(),
        }

    # ---------- Helpers ----------

    @staticmethod
    def _get(d: dict, *keys, default=None):
        """Return the first present key in d from keys; otherwise default."""
        for k in keys:
            if k in d:
                return d[k]
        return default

    def _load_seed_manifest(self, path: str) -> dict:
        """Load seed_manifest.json, translating the alternate ``seeds`` format.

        The canonical format carries ``replicate_seeds`` (replicate id string ->
        seed), ``travel_seed`` and ``dwell_seed``. The alternate format
        ``{"seeds": {"requests": a, "links": b, "departures": c}}`` is mapped to:

        * ``replicate_seeds[str(i)] = a + i``
        * ``travel_seed = b``
        * ``dwell_seed = c``

        The generated replicate table covers ids 0..999 and, when larger, the
        current ``replicate_id``. Without that, every replicate >= 1000 would
        silently fall back to the same default seed in ``__init__`` and
        produce identical request streams.
        """
        sm = read_json(path)
        seeds = sm.get("seeds") if isinstance(sm, dict) else None
        if isinstance(sm, dict) and "replicate_seeds" not in sm and isinstance(seeds, dict):
            base = int(seeds.get("requests", 12345))
            # __init__ assigns replicate_id before calling this loader; getattr keeps
            # the method usable on a partially constructed instance.
            current_rep = int(getattr(self, "replicate_id", 0) or 0)
            n_reps = max(1000, current_rep + 1)
            sm["replicate_seeds"] = {str(i): base + i for i in range(n_reps)}
            sm["travel_seed"] = int(seeds.get("links", 4242))
            sm["dwell_seed"] = int(seeds.get("departures", 31415))
        return sm

    # ---------- Policy YAML normalizer ----------
    def _normalize_policy_params(self, raw: dict) -> dict:
        raw = raw or {}
        out = {"baseline": {}, "myopic": {}, "slackaware": {}}

        ex = (raw.get("excursion_budget") or {})
        rw = (raw.get("return_window") or {})
        quota = (raw.get("quota") or {})

        ex_s = float(ex.get("cap_seconds") or 0.0)
        rw_links = int(rw.get("max_corridor_links") or 0)
        rw_s = float(rw.get("max_seconds") or 0.0)

        common = {
            "excursion_budget_s": ex_s,
            "return_window_links": rw_links,
            "return_window_time_s": rw_s,
            "max_excursions_per_trip": int(quota.get("max_excursions_per_trip") or 0),
        }
        out["myopic"].update(common)
        out["slackaware"].update(common)

        for entry in (raw.get("policies") or []):
            if not isinstance(entry, dict):
                continue
            for name, params in entry.items():
                params = params or {}
                key = None
                name_str = str(name)
                if "Myopic" in name_str:
                    key = "myopic"
                elif "SlackAware" in name_str or "Slack_Aware" in name_str:
                    key = "slackaware"
                elif "Baseline" in name_str:
                    key = "baseline"
                if key:
                    out[key].update(params)

        if "headway_cv_threshold" in out["slackaware"] and "headway_risk_threshold" not in out["slackaware"]:
            try:
                out["slackaware"]["headway_risk_threshold"] = float(out["slackaware"]["headway_cv_threshold"])
            except Exception:
                pass
        if "lateness_threshold_s" in out["slackaware"] and "headway_planned_s" not in out["slackaware"]:
            try:
                out["slackaware"]["headway_planned_s"] = float(out["slackaware"]["lateness_threshold_s"])
            except Exception:
                pass

        return out

    # ---------- Loading helpers ----------

    def _pick_col(self, df: pd.DataFrame, candidates, default=None, required=False, ctx_name=""):
        for c in candidates:
            if c in df.columns:
                return c
        if required:
            raise KeyError(f"Missing required column among {candidates} for {ctx_name}. Available: {list(df.columns)}")
        return default

    def _load_stops(self) -> List[Stop]:
        """Load stops.csv into :class:`Stop` records.

        Per row:

        * ``x``: first available of ``x_along_corridor``, ``x_along_corridor_m``,
          ``x``, ``x_coord``, ``x_pos``. When that column or cell is missing,
          it falls back to ``stop_order``, or to the row index if there is no
          ``stop_order`` column.
        * ``in_overlap``: lenient boolean parse (words, 0/1, numeric strings
          such as ``"1.0"``); missing or unrecognized values give ``False``.
        * ``stop_type``: trimmed and lower-cased; missing/blank -> ``'corridor'``.
        * ``parent_corr_stop_id``: missing/blank -> ``None``.
        """
        df = read_csv(os.path.join(self.scen_dir, "stops.csv"))
        x_col = self._pick_col(df, ["x_along_corridor", "x_along_corridor_m", "x", "x_coord", "x_pos"])
        has_order = "stop_order" in df.columns

        def to_bool(v):
            # Lenient truthiness for a CSV cell.
            if isinstance(v, (bool, np.bool_)):
                return bool(v)
            if pd.isna(v):
                return False
            s = str(v).strip().lower()
            if s in {"1", "true", "t", "yes", "y"}:
                return True
            if s in {"0", "false", "f", "no", "n"}:
                return False
            try:
                # pandas reads a 0/1 column containing blanks as float, so cells
                # arrive as "1.0"/"0.0"; int("1.0") would fail and wrongly yield False.
                num = float(s)
            except ValueError:
                return False
            return not math.isnan(num) and num != 0.0

        def is_blank(v) -> bool:
            # Blank CSV cells arrive as NaN; str(NaN) would otherwise become "nan".
            return v is None or bool(pd.isna(v)) or str(v).strip() == ""

        rows = []
        for idx, r in df.iterrows():
            order = int(r["stop_order"]) if has_order else int(idx)
            x_raw = r.get(x_col, np.nan) if x_col is not None else np.nan
            x_val = float(x_raw) if pd.notna(x_raw) else float(order)
            parent_val = r.get("parent_corr_stop_id", None)
            type_val = r.get("stop_type", None)
            rows.append(Stop(
                stop_id=str(r["stop_id"]),
                x=x_val,
                line_id=str(r["line_id"]),
                in_overlap=to_bool(r.get("in_overlap", False)),
                stop_type="corridor" if is_blank(type_val) else str(type_val).strip().lower(),
                parent_corr_stop_id=None if is_blank(parent_val) else str(parent_val),
                stop_order=order,
            ))
        return rows

    def _load_links(self) -> List[Link]:
        """Load links.csv into :class:`Link` records.

        Columns are matched by alias (see the candidate lists). Per row, with
        NaN/blank cells treated as missing:

        * ``T0``: the explicit column; else distance / speed when both are
          present and speed > 0; else 0.0.
        * ``sigma``: defaults to 0.2.
        * ``control_point_flag``: defaults to 0.
        * ``link_type``: trimmed and lower-cased; defaults to ``'corridor'``.
        """
        df = read_csv(os.path.join(self.scen_dir, "links.csv"))
        from_col = self._pick_col(df, ["from_stop","from","u","origin"], required=True, ctx_name="links.from_stop")
        to_col   = self._pick_col(df, ["to_stop","to","v","dest"], required=True, ctx_name="links.to_stop")
        t0_col   = self._pick_col(df, ["T0","T0_s","t0","base_time_s","mean_time_s","freeflow_s","tt0"], required=False, ctx_name="links.T0")
        sigma_col= self._pick_col(df, ["sigma","sigma_logn","sigma_ln","std_ln","cv"], required=False, ctx_name="links.sigma")
        cp_col   = self._pick_col(df, ["control_point_flag","control_point","is_cp","cp"], required=False)
        lt_col   = self._pick_col(df, ["link_type","type"], required=False)

        dist_col = self._pick_col(df, ["distance_m","dist_m","distance"], required=False)
        speed_col= self._pick_col(df, ["speed_mps","speed","v_mps"], required=False)

        def has_value(r, col) -> bool:
            # Column exists and this row's cell is neither NaN nor blank.
            if col is None:
                return False
            v = r.get(col, np.nan)
            return bool(pd.notna(v)) and str(v).strip() != ""

        rows = []
        for _, r in df.iterrows():
            if has_value(r, t0_col):
                T0 = float(r[t0_col])
            elif has_value(r, dist_col) and has_value(r, speed_col) and float(r[speed_col]) > 0:
                T0 = float(r[dist_col]) / float(r[speed_col])
            else:
                T0 = 0.0

            sigma_val = float(r[sigma_col]) if has_value(r, sigma_col) else 0.2
            cp_flag = int(r[cp_col]) if has_value(r, cp_col) else 0
            # A blank link_type cell is NaN; str(NaN) == "nan" would mislabel the
            # link and hide it from the corridor-only link searches.
            ltype = str(r[lt_col]).strip().lower() if has_value(r, lt_col) else "corridor"
            rows.append(Link(
                from_stop=str(r[from_col]),
                to_stop=str(r[to_col]),
                T0=float(T0),
                sigma=float(sigma_val),
                control_point_flag=int(cp_flag),
                link_type=ltype,
            ))
        return rows

    def _load_excursions(self) -> List[ExcursionArc]:
        """Load the optional excursion_arcs.csv into :class:`ExcursionArc` records.

        Returns ``[]`` when the file does not exist. Per row, NaN cells count
        as missing:

        * ``T0``: the explicit column; else ``distance / 10.0`` (an implied
          10 m/s if distance is in metres); else 0.0.
        * ``sigma``: defaults to 0.25.
        * ``turn_penalty_s``: defaults to 0.0.
        * ``distance_m``: defaults to 0.0.
        """
        path = os.path.join(self.scen_dir, "excursion_arcs.csv")
        if not os.path.exists(path):
            return []
        df = read_csv(path)
        fromc = self._pick_col(df, ["from_corr_stop_id","from_corr","from_stop","parent_corr_stop_id"], required=True, ctx_name="excursions.from_corr_stop_id")
        toside= self._pick_col(df, ["to_side_stop_id","to_side","side_stop_id","side_stop"], required=True, ctx_name="excursions.to_side_stop_id")
        distc = self._pick_col(df, ["distance_m","dist_m","distance"], required=False)
        tpen  = self._pick_col(df, ["turn_penalty_s","turn_s","turn_penalty"], required=False)
        t0c   = self._pick_col(df, ["T0","T0_s","t0","base_time_s","mean_time_s","freeflow_s","tt0"], required=False)
        sigc  = self._pick_col(df, ["sigma","sigma_logn","sigma_ln","std_ln","cv"], required=False)

        def number(row, col):
            # Float value of `col` in this row, or None when the column or value is absent.
            if col is None:
                return None
            cell = row.get(col, np.nan)
            return float(row[col]) if pd.notna(cell) else None

        arcs = []
        for _, row in df.iterrows():
            nominal = number(row, t0c)
            if nominal is None:
                dist_for_t0 = number(row, distc)
                nominal = dist_for_t0 / 10.0 if dist_for_t0 is not None else 0.0
            spread = number(row, sigc)
            penalty = number(row, tpen)
            dist = number(row, distc)
            arcs.append(ExcursionArc(
                from_corr_stop_id=str(row[fromc]),
                to_side_stop_id=str(row[toside]),
                distance_m=0.0 if dist is None else dist,
                turn_penalty_s=0.0 if penalty is None else penalty,
                T0=nominal,
                sigma=0.25 if spread is None else spread,
            ))
        return arcs

    def _load_departures(self) -> List[Departure]:
        """Load departures.csv into :class:`Departure` records with times in seconds.

        The time column is taken as seconds when its name says so: a ``_s`` or
        ``_sec`` suffix, or the name ``seconds``. Otherwise the unit is
        inferred. If the largest numeric value is <= 2880 (48 h expressed in
        minutes), the column is assumed to be minutes and scaled by 60.

        NOTE(review): that heuristic misreads a seconds-valued column whose
        values all fall within the first 48 minutes; prefer explicit ``*_s``
        column names. Missing times become 0.0.
        """
        df = read_csv(os.path.join(self.scen_dir, "departures.csv"))
        dep_id_col = self._pick_col(df, ["dep_id","departure_id","id"], required=True, ctx_name="departures.dep_id")
        line_col   = self._pick_col(df, ["line_id","line","route","route_id"], required=True, ctx_name="departures.line_id")
        tdep_col   = self._pick_col(df, [
            "t_depart","t_depart_s","t","time_s","time","depart_s","depart_sec",
            "departure_time_s","departure_time","t0","seconds"
        ], required=True, ctx_name="departures.t_depart")

        col_name = str(tdep_col).lower()
        if col_name.endswith(("_s", "_sec")) or col_name == "seconds":
            # Explicitly seconds-named columns must never go through the minutes heuristic.
            scale = 1.0
        else:
            vals = pd.to_numeric(df[tdep_col], errors="coerce")
            if vals.notna().any():
                scale = 60.0 if float(vals.max()) <= 2880.0 else 1.0
            else:
                scale = 1.0

        rows = []
        for _, r in df.iterrows():
            td = float(r[tdep_col]) * scale if pd.notna(r[tdep_col]) else 0.0
            rows.append(Departure(
                dep_id=str(r[dep_id_col]),
                line_id=str(r[line_col]),
                t_depart=td,
            ))
        return rows

    def _load_requests(self) -> List[Request]:
        """Load requests.csv into :class:`Request` records with times in seconds.

        Columns are matched by alias (see the candidate lists). The request
        time and patience columns are taken as seconds when the name says so:
        a ``_s`` or ``_sec`` suffix, or the name ``seconds``. Otherwise, a
        column whose largest numeric value is <= 2880 is assumed to be minutes
        and scaled by 60.

        NOTE(review): that heuristic misreads short seconds-valued columns
        (e.g. a plain ``patience`` column of 600 s); prefer ``*_s`` names.

        Defaults for missing values: time 0.0, group_size 1, patience 9999.0
        (not scaled), o_type/d_type ``'corridor'``.
        """
        df = read_csv(os.path.join(self.scen_dir, "requests.csv"))
        id_col   = self._pick_col(df, ["req_id","request_id","id"], required=True, ctx_name="requests.req_id")
        t_col    = self._pick_col(df, [
            "t_request","t_request_s","t","time_s","time","seconds",
            "arrival_time_s","arrival_time","t_req","request_time_s","request_time"
        ], required=True, ctx_name="requests.t_request")
        o_col    = self._pick_col(df, ["o_stop","origin_stop","o","origin","from_stop"], required=True, ctx_name="requests.o_stop")
        d_col    = self._pick_col(df, ["d_stop","dest_stop","d","destination","to_stop"], required=True, ctx_name="requests.d_stop")
        g_col    = self._pick_col(df, ["group_size","party","size","n","n_pax"], required=False)
        pat_col  = self._pick_col(df, ["patience","patience_s","max_wait_s","max_wait","wait_limit_s","wait_limit","abandon_after_s"], required=False)
        ot_col   = self._pick_col(df, ["o_type","origin_type"], required=False)
        dt_col   = self._pick_col(df, ["d_type","dest_type"], required=False)

        def seconds_scale(col) -> float:
            # 1.0 for explicitly seconds-named columns; otherwise the minutes heuristic.
            name = str(col).lower()
            if name.endswith(("_s", "_sec")) or name == "seconds":
                return 1.0
            vals = pd.to_numeric(df[col], errors="coerce")
            if vals.notna().any():
                return 60.0 if float(vals.max()) <= 2880.0 else 1.0
            return 1.0

        t_scale = seconds_scale(t_col)
        p_scale = seconds_scale(pat_col) if pat_col is not None else 1.0

        rows = []
        for _, r in df.iterrows():
            rid = str(r[id_col])
            treq = float(r[t_col]) * t_scale if pd.notna(r[t_col]) else 0.0
            ostop = str(r[o_col])
            dstop = str(r[d_col])
            gsize = int(r[g_col]) if (g_col is not None and pd.notna(r[g_col])) else 1
            pat   = float(r[pat_col]) * p_scale if (pat_col is not None and pd.notna(r[pat_col])) else 9999.0
            otype = str(r[ot_col]) if (ot_col is not None and pd.notna(r[ot_col])) else "corridor"
            dtype = str(r[dt_col]) if (dt_col is not None and pd.notna(r[dt_col])) else "corridor"
            rows.append(Request(
                req_id=rid,
                t_request=treq,
                o_stop=ostop,
                d_stop=dstop,
                group_size=gsize,
                patience=pat,
                o_type=otype,
                d_type=dtype,
            ))
        return rows

    def _init_vehicles(self):
        """Populate ``self.vehicles`` with one :class:`Vehicle` per vehicles.csv row.

        The budget and return-window columns (``excursion_budget_s``,
        ``return_window_links``, ``return_window_time_s``) override the active
        policy's normalized defaults only when they hold a positive number.
        Missing, NaN, zero, negative or unparseable values fall back to the
        default, or to 0.0 if the default cannot be converted either.

        The optional ``excursion_quota`` column becomes
        ``excursion_quota_left``; a missing column or NaN cell means no quota
        (``None``).

        NOTE(review): the policy-level ``max_excursions_per_trip`` is not
        consulted here, so vehicles without an ``excursion_quota`` value get
        no quota — confirm whether it is applied elsewhere.
        """
        policy_defaults = self.policy_params.get(self.policy_name, {}) or {}

        def positive_or(value, fallback):
            # Prefer a positive, non-NaN `value`; otherwise coerce `fallback` (0.0 on failure).
            try:
                number = float(value)
                if number > 0 and not math.isnan(number):
                    return number
            except Exception:
                pass
            try:
                return float(fallback)
            except Exception:
                return 0.0

        def quota_from(row):
            if "excursion_quota" in row and not pd.isna(row["excursion_quota"]):
                return int(row["excursion_quota"])
            return None

        for _, row in self.vehicles_df.iterrows():
            vehicle = Vehicle(
                vehicle_id=str(row["vehicle_id"]),
                line_id=str(row["line_id"]),
                capacity=int(row["capacity"]),
                onboard=0,
                loc_stop=None,
                status="idle",
                excursion_budget_s=positive_or(
                    row.get("excursion_budget_s"),
                    policy_defaults.get("excursion_budget_s", 0.0),
                ),
                return_window_links=int(positive_or(
                    row.get("return_window_links"),
                    policy_defaults.get("return_window_links", 0),
                )),
                return_window_time_s=positive_or(
                    row.get("return_window_time_s"),
                    policy_defaults.get("return_window_time_s", 0.0),
                ),
                excursion_quota_left=quota_from(row),
            )
            self.vehicles[vehicle.vehicle_id] = vehicle

    # ---------- Stochastic models ----------

    def draw_travel_time(self, link_like) -> float:
        factor = self.rng_travel.lognormal(mean=0.0, sigma=max(1e-6, float(link_like.sigma)))
        return max(0.0, float(link_like.T0) * factor)

    def draw_dwell_time(self, boarders: int, alighters: int) -> float:
        a  = float(self._get(self.dwell_params, "alpha", "alpha_s", default=5.0))
        bb = float(self._get(self.dwell_params, "beta_b", "beta_board", "beta_board_s", default=1.0))
        ba = float(self._get(self.dwell_params, "beta_a", "beta_alight", "beta_alight_s", default=1.0))
        sig= float(self._get(self.dwell_params, "sigma", "noise_sigma", "noise_sigma_s", default=0.5))
        noise = self.rng_dwell.normal(0.0, sig)
        return max(0.0, a + bb * boarders + ba * alighters + noise)

    # ---------- Headway risk proxy ----------

    def headway_risk_score(self, vehicle: Vehicle) -> float:
        occ = vehicle.onboard / max(1, vehicle.capacity)
        lateness = min(1.0, vehicle.headway_clock / max(1.0, self.policy.params.get("headway_planned_s", 900)))
        return 0.5 * occ + 0.5 * lateness

    # ---------- Corridor helpers ----------

    def _next_corridor_link(self, vehicle: Vehicle) -> Optional[Link]:
        options = [lk for lk in self.links_from.get(vehicle.loc_stop, []) if lk.link_type == "corridor"]
        if not options:
            return None
        def order_of(stop_id):
            return self.stop_by_id[stop_id].stop_order
        options.sort(key=lambda lk: order_of(lk.to_stop))
        return options[0]

    def _next_corridor_link_from(self, start_stop: str) -> Optional[Link]:
        options = [lk for lk in self.links_from.get(start_stop, []) if lk.link_type == "corridor"]
        if not options:
            return None
        def order_of(stop_id):
            return self.stop_by_id[stop_id].stop_order
        options.sort(key=lambda lk: order_of(lk.to_stop))
        return options[0]

    def _planned_time_for_next_links(self, start_stop: str, n_links: int) -> float:
        """Sum nominal T0 for the next n corridor links from start_stop."""
        t = 0.0; cur = start_stop
        for _ in range(max(0, int(n_links))):
            lk = self._next_corridor_link_from(cur)
            if lk is None:
                break
            t += float(lk.T0)
            cur = lk.to_stop
        return t

    def _effective_return_window_seconds(self, parent_stop_id: str, vehicle: Vehicle) -> float:
        """Compute min{T_max, time to traverse next N corridor links from parent}."""
        links = int(vehicle.return_window_links or 0)
        t_max = float(vehicle.return_window_time_s or 0.0)
        t_links = self._planned_time_for_next_links(parent_stop_id, links) if links > 0 else float("inf")
        a = t_max if t_max > 0 else float("inf")
        b = t_links if links > 0 else float("inf")
        eff = min(a, b)
        if eff == float("inf"):
            eff = 0.0
        return eff

    # ---------- Feasibility checks ----------
    
    def compute_feasibility(self, vehicle: Vehicle, req: Request) -> dict:
        is_corridor_pickup = (self.stop_by_id[req.o_stop].stop_type == "corridor")
        within_capacity = (vehicle.onboard + req.group_size) <= vehicle.capacity
    
        budget_ok = True
        return_window_ok = True
    
        # Only excursions (side pickups) are subject to budget/window checks
        if self.stop_by_id[req.o_stop].stop_type == "side":
            parent_id = self.stop_by_id[req.o_stop].parent_corr_stop_id
            parent = self.stop_by_id.get(parent_id) if parent_id else None
            veh_loc = self.stop_by_id.get(vehicle.loc_stop) if vehicle.loc_stop else None
    
            # Reachability: same corridor context and "ahead" in current direction
            reachable = False
            if parent is not None and veh_loc is not None:
                same_corridor = (parent.line_id == vehicle.line_id) or parent.in_overlap or ("AB" in str(parent.line_id).upper())
                nxt = self._next_corridor_link(vehicle)
                if nxt is not None:
                    cur_order = veh_loc.stop_order
                    to_order = self.stop_by_id[nxt.to_stop].stop_order
                    dir_sign = np.sign(to_order - cur_order)  # +1 forward, -1 backward
                    delta = parent.stop_order - cur_order
                    ahead = (dir_sign >= 0 and delta >= 0) or (dir_sign < 0 and delta <= 0)
                    reachable = same_corridor and ahead
    
            defaults = self.policy.params
            budget_s = vehicle.excursion_budget_s or float(defaults.get("excursion_budget_s", 0.0))
            links    = vehicle.return_window_links or int(defaults.get("return_window_links", 0))
            time_s   = vehicle.return_window_time_s or float(defaults.get("return_window_time_s", 0.0))
    
            # Find the excursion arc from the parent corridor stop to the side stop
            arcs = [ea for ea in self.excursion_arcs
                    if ea.from_corr_stop_id == (parent.stop_id if parent else None)
                    and ea.to_side_stop_id == req.o_stop]
    
            # Early reject only if there is no arc, it is unreachable, or budget is non-positive
            if len(arcs) == 0 or not reachable or budget_s <= 0:
                budget_ok = False
                return_window_ok = False
            else:
                ea = arcs[0]
                out_time  = self.draw_travel_time(ea) + float(ea.turn_penalty_s or 0.0)
                back_time = self.draw_travel_time(ea) + float(ea.turn_penalty_s or 0.0)
                excursion_time = out_time + back_time
    
                # Expected side dwell (mean) used at planning time
                exp_side_dwell = float(self._get(self.dwell_params, "alpha", "alpha_s", default=5.0))
    
                # Budget is independent of the window
                budget_ok = (excursion_time + exp_side_dwell) <= budget_s
    
                # Window logic: allow either limit to be present; require at least one
                has_time_window = (time_s > 0)
                has_link_window = (links >= 1)
    
                if not has_time_window and not has_link_window:
                    # No return window configured at all -> treat as infeasible under window constraint
                    return_window_ok = False
                else:
                    # Compute effective seconds from link window if present; otherwise use time_s
                    eff_window_s = 0.0
                    if has_link_window and parent is not None:
                        eff_window_s = self._effective_return_window_seconds(parent.stop_id, vehicle)
                    window_s = eff_window_s if (has_link_window and eff_window_s > 0) else time_s
                    return_window_ok = (excursion_time + exp_side_dwell) <= window_s
    
        return {
            "is_corridor_pickup": is_corridor_pickup,
            "within_capacity": within_capacity,
            "budget_ok": budget_ok,
            "return_window_ok": return_window_ok,
        }

    # ---------- Event processing ----------

    def schedule_event(self, time_val: float, typ: str, payload: dict, priority: int = 0):
        """Push an :class:`Event` onto ``self.evq`` unless it lies past the horizon.

        Events with ``time_val > horizon_s`` are silently dropped; an event
        exactly at the horizon is kept. Lower ``priority`` values pop first
        among events at the same time.
        """
        beyond_horizon = time_val > self.horizon_s
        if beyond_horizon:
            return
        event = Event(time=time_val, priority=priority, typ=typ, payload=payload)
        heapq.heappush(self.evq, event)

    def log_event(self, **kwargs):
        self._events_wr.writerow(kwargs)
        t = kwargs.get('type')
        if t == 'depart_stop':
            tt = float(kwargs.get('ttravel', 0.0) or 0.0)
            self._stats['depart_count'] += 1
            self._stats['depart_ttravel_sum'] += tt
            self._stats['depart_ttravel_sqsum'] += tt*tt
        elif t == 'vehicle_depart':
            self._stats['vehicle_depart_count'] += 1

    def log_decision(self, **kwargs):
        self._dec_wr.writerow(kwargs)
        d = kwargs.get('decision')
        if d == 'accept':
            self._stats['accept_count'] += 1
        elif d == 'abandon':
            self._stats['abandon_count'] += 1

    # ---------- Deterministic initialization ----------

    def init_events(self):
        # First corridor stop per line
        line_first_stop: Dict[str, str] = {}
        for s in sorted(self.stops, key=lambda z: (z.line_id, z.stop_order)):
            if s.stop_type == "corridor":
                line_first_stop.setdefault(s.line_id, s.stop_id)

        # Build departure lookups
        dep_by_id: Dict[str, Departure] = {d.dep_id: d for d in self.departures}
        deps_by_line: Dict[str, List[Departure]] = {}
        for d in self.departures:
            deps_by_line.setdefault(d.line_id, []).append(d)
        for ln in deps_by_line:
            deps_by_line[ln].sort(key=lambda d: d.t_depart)

        # Deterministic vehicle -> departure mapping
        veh_has_dep_col = "dep_id" in self.vehicles_df.columns
        if veh_has_dep_col:
            # Use explicit dep_id mapping from vehicles.csv
            for _, r in self.vehicles_df.iterrows():
                vid = str(r["vehicle_id"]); line_id = str(r["line_id"])
                dep_id = str(r["dep_id"])
                dep = dep_by_id.get(dep_id)
                if dep is None:
                    continue
                if dep.t_depart > self.horizon_s:
                    continue
                v = self.vehicles[vid]
                v.loc_stop = line_first_stop.get(line_id, v.loc_stop)
                v.status = "reserved"
                self.schedule_event(dep.t_depart, "vehicle_depart", {"vehicle_id": vid, "dep_id": dep_id})
        else:
            # Fallback: zip vehicles and departures deterministically per line
            vehs_by_line: Dict[str, List[str]] = {}
            for _, r in self.vehicles_df.iterrows():
                vehs_by_line.setdefault(str(r["line_id"]), []).append(str(r["vehicle_id"]))
            for ln, dep_list in deps_by_line.items():
                vlist = sorted(vehs_by_line.get(ln, []))
                if not vlist:
                    continue
                # 1-1 if sizes match; otherwise reuse vehicles in round-robin (deterministic)
                for i, dep in enumerate(dep_list):
                    vid = vlist[i % len(vlist)]
                    if dep.t_depart > self.horizon_s:
                        continue
                    v = self.vehicles[vid]
                    v.loc_stop = line_first_stop.get(ln, v.loc_stop)
                    v.status = "reserved"
                    self.schedule_event(dep.t_depart, "vehicle_depart", {"vehicle_id": vid, "dep_id": dep.dep_id})

        # Requests arrivals
        for req in self.requests:
            if req.t_request <= self.horizon_s:
                self.schedule_event(req.t_request, "request_arrival", {"req_id": req.req_id})

    # --- Main loop ---

    def run(self):
        self.init_events()

        ctx = {"headway_risk_score": self.headway_risk_score}

        while self.evq:
            ev = heapq.heappop(self.evq)
            if ev.time > self.horizon_s:
                break
            self.t = ev.time
            if ev.typ == "vehicle_depart":
                self._handle_vehicle_depart(ev.payload)
            elif ev.typ == "arrive_stop":
                self._handle_arrive_stop(ev.payload)
            elif ev.typ == "dwell_complete":
                self._handle_dwell_complete(ev.payload)
            elif ev.typ == "request_arrival":
                self._handle_request_arrival(ev.payload, ctx)
            elif ev.typ == "request_timeout":
                self._handle_request_timeout(ev.payload)
            elif ev.typ == "excursion_depart":
                self._handle_excursion_depart(ev.payload)
            elif ev.typ == "excursion_arrive_side":
                self._handle_excursion_arrive_side(ev.payload)
            elif ev.typ == "excursion_return":
                self._handle_excursion_return(ev.payload)
            elif ev.typ == "excursion_rejoin":
                self._handle_excursion_rejoin(ev.payload)
            else:
                pass

        self._finalize_and_write()

    # --- Event handlers (corridor) ---

    def _handle_vehicle_depart(self, payload):
        vid = payload["vehicle_id"]
        veh = self.vehicles[vid]
        veh.status = "enroute"
        veh.headway_clock = 0.0
        link = self._next_corridor_link(veh)
        if link is None:
            return
        ttravel = self.draw_travel_time(link)
        veh.current_link = (link.from_stop, link.to_stop)
        self.schedule_event(self.t + ttravel, "arrive_stop", {"vehicle_id": vid, "to_stop": link.to_stop})
        self.log_event(time=self.t, type="vehicle_depart", vehicle_id=vid, from_stop=link.from_stop, to_stop=link.to_stop)

    def _handle_arrive_stop(self, payload):
        vid = payload["vehicle_id"]
        to_stop = payload["to_stop"]
        veh = self.vehicles[vid]
        veh.loc_stop = to_stop
        veh.status = "dwell"
        veh.current_link = None
        tdwell = self.draw_dwell_time(boarders=0, alighters=0)
        self.schedule_event(self.t + tdwell, "dwell_complete", {"vehicle_id": vid})
        self.log_event(time=self.t, type="arrive_stop", vehicle_id=vid, stop_id=to_stop, tdwell=tdwell)

    def _handle_dwell_complete(self, payload):
        vid = payload["vehicle_id"]
        veh = self.vehicles[vid]

        # If a pending excursion is queued for this corridor stop, do it now.
        tasks = self.pending_excursions.get(vid, [])
        task_idx = None
        for i, tsk in enumerate(tasks):
            if tsk["parent_stop"] == veh.loc_stop:
                task_idx = i
                break
            # Fallback: match by (line_id, stop_order)
            ps = self.stop_by_id.get(tsk["parent_stop"])
            vs = self.stop_by_id.get(veh.loc_stop) if veh.loc_stop else None
            if ps and vs and ps.line_id == vs.line_id and ps.stop_order == vs.stop_order:
                task_idx = i
                break

        if task_idx is not None:
            tsk = tasks.pop(task_idx)
            eff_window = self._effective_return_window_seconds(tsk["parent_stop"], veh)
            self.active_excursion[vid] = {
                "req_id": tsk["req_id"],
                "parent_stop": tsk["parent_stop"],
                "side_stop": tsk["side_stop"],
                "start_time": self.t,
                "deadline": (self.t + eff_window) if eff_window > 0 else (self.t + veh.return_window_time_s if veh.return_window_time_s > 0 else float("inf")),
                "out_time": None,
                "back_time": None,
                "distance_km": 0.0,
            }
            if veh.excursion_quota_left is not None and veh.excursion_quota_left > 0:
                veh.excursion_quota_left -= 1
            self.schedule_event(self.t, "excursion_depart", {"vehicle_id": vid})
            return

        veh.status = "enroute"
        link = self._next_corridor_link(veh)
        if link is None:
            return
        ttravel = self.draw_travel_time(link)
        veh.current_link = (link.from_stop, link.to_stop)
        veh.headway_clock += ttravel
        self.schedule_event(self.t + ttravel, "arrive_stop", {"vehicle_id": vid, "to_stop": link.to_stop})
        self.log_event(time=self.t, type="depart_stop", vehicle_id=vid, from_stop=link.from_stop, to_stop=link.to_stop, ttravel=ttravel)

    # --- Event handlers (excursions) ---

    def _handle_excursion_depart(self, payload):
        vid = payload["vehicle_id"]
        veh = self.vehicles[vid]
        ex = self.active_excursion.get(vid)
        if not ex:
            return

        parent = ex["parent_stop"]
        side = ex["side_stop"]
        ea = self.excursion_by_pair.get((parent, side))
        if ea is None:
            self._resume_corridor_after_excursion_abort(vid)
            return

        veh.status = "excursion"
        veh.current_link = (parent, side)
        out_time = self.draw_travel_time(ea) + float(ea.turn_penalty_s or 0.0)
        ex["out_time"] = out_time
        ex["distance_km"] = 2.0 * (float(ea.distance_m or 0.0) / 1000.0)
        self.schedule_event(self.t + out_time, "excursion_arrive_side", {"vehicle_id": vid})
        self.log_event(time=self.t, type="excursion_depart", vehicle_id=vid, from_stop=parent, to_stop=side, ttravel=out_time)

    def _handle_excursion_arrive_side(self, payload):
        vid = payload["vehicle_id"]
        veh = self.vehicles[vid]
        ex = self.active_excursion.get(vid)
        if not ex:
            return
        side = ex["side_stop"]
        veh.loc_stop = side
        veh.current_link = None
        tdwell = self.draw_dwell_time(boarders=1, alighters=0)
        self.schedule_event(self.t + tdwell, "excursion_return", {"vehicle_id": vid})
        self.log_event(time=self.t, type="excursion_arrive_side", vehicle_id=vid, stop_id=side, tdwell=tdwell)

    def _handle_excursion_return(self, payload):
        vid = payload["vehicle_id"]
        veh = self.vehicles[vid]
        ex = self.active_excursion.get(vid)
        if not ex:
            return
        parent = ex["parent_stop"]
        side = ex["side_stop"]
        ea = self.excursion_by_pair.get((parent, side))
        if ea is None:
            self._resume_corridor_after_excursion_abort(vid)
            return
        back_time = self.draw_travel_time(ea) + float(ea.turn_penalty_s or 0.0)
        ex["back_time"] = back_time
        veh.current_link = (side, parent)
        self.schedule_event(self.t + back_time, "excursion_rejoin", {"vehicle_id": vid})
        self.log_event(time=self.t, type="excursion_return", vehicle_id=vid, from_stop=side, to_stop=parent, ttravel=back_time)

    def _handle_excursion_rejoin(self, payload):
        vid = payload["vehicle_id"]
        veh = self.vehicles[vid]
        ex = self.active_excursion.get(vid)
        if not ex:
            return
        parent = ex["parent_stop"]

        excursion_duration = (self.t - ex["start_time"])
        veh.headway_clock += excursion_duration

        if self.t > ex.get("deadline", float("inf")):
            self._stats["missed_return_window_count"] += 1

        self._stats["excursions_km"] += max(0.0, ex.get("distance_km", 0.0))
        self._stats["excursions_count"] += 1

        self.active_excursion[vid] = None
        veh.status = "enroute"
        veh.loc_stop = parent
        veh.current_link = None

        link = self._next_corridor_link(veh)
        if link is None:
            return
        ttravel = self.draw_travel_time(link)
        veh.current_link = (link.from_stop, link.to_stop)
        self.schedule_event(self.t + ttravel, "arrive_stop", {"vehicle_id": vid, "to_stop": link.to_stop})
        self.log_event(time=self.t, type="depart_stop", vehicle_id=vid, from_stop=link.from_stop, to_stop=link.to_stop, ttravel=ttravel)

    def _resume_corridor_after_excursion_abort(self, vid: str):
        veh = self.vehicles[vid]
        self.active_excursion[vid] = None
        veh.status = "enroute"
        link = self._next_corridor_link(veh)
        if link is None:
            return
        ttravel = self.draw_travel_time(link)
        veh.current_link = (link.from_stop, link.to_stop)
        self.schedule_event(self.t + ttravel, "arrive_stop", {"vehicle_id": vid, "to_stop": link.to_stop})
        self.log_event(time=self.t, type="depart_stop", vehicle_id=vid, from_stop=link.from_stop, to_stop=link.to_stop, ttravel=ttravel)

    # --- Event handlers (requests) ---

    def _handle_request_arrival(self, payload, ctx):
        req_id = payload["req_id"]
        req = next(r for r in self.requests if r.req_id == req_id)
        self.req_status[req_id] = {"state": "waiting", "t_request": req.t_request}

        candidates = []
        for vid, veh in self.vehicles.items():
            # Only vehicles that are operating (already departed) are considered
            if veh.status not in ("enroute", "dwell", "excursion"):
                continue
            feas = self.compute_feasibility(veh, req)
            policy_accept, reason = self.policy.decide_accept(ctx, veh, req, feas)
            if policy_accept:
                # stochastic ETA around 120s to inject replicate variability
                eta = self.t + float(self.rng_travel.lognormal(mean=np.log(120.0), sigma=0.35))
                candidates.append((eta, vid, feas, reason))

        if candidates:
            candidates.sort(key=lambda z: z[0])
            eta, vid, feas, reason = candidates[0]
            self.req_status[req_id]["state"] = "accepted"
            self.req_status[req_id]["vehicle_id"] = vid
            self.req_status[req_id]["t_assign"] = self.t
            wait_s = max(0.0, eta - req.t_request)
            self._stats["wait_sum"] += wait_s
            self._stats["wait_n"] += 1

            if not feas["is_corridor_pickup"]:
                parent_id = self.stop_by_id[req.o_stop].parent_corr_stop_id
                if parent_id is not None:
                    parent = self.stop_by_id.get(parent_id)
                    if parent is not None:
                        canon_parent_id = self.corridor_stop_by_line_order.get((parent.line_id, parent.stop_order), parent.stop_id)
                    else:
                        canon_parent_id = parent_id
                    self.pending_excursions.setdefault(vid, []).append({
                        "req_id": req_id,
                        "side_stop": req.o_stop,
                        "parent_stop": canon_parent_id,  # canonicalized ID
                    })
            self.log_decision(time=self.t, req_id=req_id, vehicle_id=vid, decision="accept", reason=reason,
                              is_corridor_pickup=feas["is_corridor_pickup"], budget_ok=feas["budget_ok"],
                              return_window_ok=feas["return_window_ok"], within_capacity=feas["within_capacity"])
        else:
            patience = req.patience if req.patience is not None else 900.0
            self.schedule_event(self.t + patience, "request_timeout", {"req_id": req_id})
            # No counter change here (final outcome will be 'abandon' if it times out)

    def _handle_request_timeout(self, payload):
        req_id = payload["req_id"]
        st = self.req_status.get(req_id, {})
        if st.get("state") == "waiting":
            st["state"] = "abandoned"
            st["t_abandon"] = self.t
            t_request = st.get("t_request", self.t)
            self._stats["wait_sum"] += max(0.0, self.t - t_request)
            self._stats["wait_n"] += 1
            self.log_decision(time=self.t, req_id=req_id, vehicle_id=None, decision="abandon", reason="patience_expired",
                              is_corridor_pickup=None, budget_ok=None, return_window_ok=None, within_capacity=None)

    # ---------- Outputs ----------

    def _finalize_and_write(self):
        try:
            self._events_fp.flush(); self._events_fp.close()
        except Exception:
            pass
        try:
            self._dec_fp.flush(); self._dec_fp.close()
        except Exception:
            pass

        # Headway CV proxy from depart_stop legs
        depart_n = self._stats["depart_count"]
        if depart_n > 1:
            mean_t = self._stats["depart_ttravel_sum"] / depart_n
            var_t = max(0.0, (self._stats["depart_ttravel_sqsum"]/depart_n) - mean_t*mean_t)
            headway_cv = (var_t ** 0.5) / (mean_t + 1e-6)
        else:
            headway_cv = float('nan')

        # Final-outcome-based acceptance/abandon percentages
        final_n = self._stats["accept_count"] + self._stats["abandon_count"]
        accepted_pct = 100.0 * (self._stats["accept_count"] / max(1, final_n))
        abandon_pct  = 100.0 * (self._stats["abandon_count"] / max(1, final_n))

        # Mean wait across accepted and abandoned requests
        wait_mean_s = (self._stats["wait_sum"] / self._stats["wait_n"]) if self._stats["wait_n"] > 0 else float("nan")

        # Crude vehicle-km proxy
        vehicle_km = float(depart_n) * 0.5

        # Excursion share
        excursion_km_share = (self._stats["excursions_km"] / vehicle_km) if vehicle_km > 0 else 0.0
        excursion_km_share = max(0.0, min(1.0, excursion_km_share))

        # Missed return-window rate among realized excursions
        if self._stats["excursions_count"] > 0:
            missed_return_window_rate = self._stats["missed_return_window_count"] / float(self._stats["excursions_count"])
        else:
            missed_return_window_rate = 0.0

        kpis_df = pd.DataFrame([{
            "wait_mean_s": wait_mean_s,
            "accepted_pct": accepted_pct,
            "abandon_pct": abandon_pct,
            "vehicle_km": vehicle_km,
            "excursion_km_share": excursion_km_share,
            "dispatch_count": int(self._stats["vehicle_depart_count"]),
            "headway_cv": headway_cv,
            "missed_return_window_rate": missed_return_window_rate,
            "policy": self.policy_name,
            "replicate": self.replicate_id,
            "scenario": os.path.basename(self.scen_dir),
        }])
        kpis_df.to_csv(os.path.join(self.out_dir, "kpis.csv"), index=False)

        self.meta["ended_at"] = now_iso()
        write_json(self.meta, os.path.join(self.out_dir, "run_meta.json"))

# ------------- Driver -------------

def discover_scenarios(root: str) -> List[str]:
    """Return, sorted, every directory under *root* (inclusive) holding a complete scenario.

    A scenario directory must directly contain all eight input files: stops,
    links, departures, requests and vehicles CSVs, the dwell and policy parameter
    YAMLs, and the seed manifest JSON. A missing *root* yields an empty list.
    """
    needed = frozenset({
        "stops.csv", "links.csv", "departures.csv", "dwell_params.yml",
        "requests.csv", "vehicles.csv", "policy_params.yml", "seed_manifest.json",
    })
    return sorted(
        dirpath
        for dirpath, _subdirs, files in os.walk(root)
        if needed <= set(files)
    )

def unzip_if_needed(zip_path: str, target_root: str) -> str:
    """Extract *zip_path* under *target_root* unless already extracted; return the scenarios root.

    Extraction is skipped when every top-level entry of the archive already exists
    under *target_root*. Top-level names are taken from all entries, not just
    explicit ``dir/`` entries. The previous check only looked at nested directory
    entries, so archives without them (many archivers omit them) were
    re-extracted on every call, overwriting local edits.

    Returns:
        ``<target_root>/scenarios`` when the archive is rooted at ``scenarios/``;
        otherwise the parent of the first directory containing ``stops.csv``
        found under *target_root* (walked in sorted order for determinism);
        otherwise *target_root* itself.
    """
    os.makedirs(target_root, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as zf:
        names = zf.namelist()
        top_level = {n.split('/', 1)[0] for n in names}
        top_level.discard('')  # entries with a leading '/' have no usable top-level name
        already_extracted = bool(top_level) and all(
            os.path.exists(os.path.join(target_root, name)) for name in top_level
        )
        if not already_extracted:
            zf.extractall(target_root)

    if any(n.startswith('scenarios/') for n in names):
        return os.path.join(target_root, 'scenarios')
    for root, dirs, files in os.walk(target_root):
        dirs.sort()  # deterministic traversal order across platforms
        if 'stops.csv' in files:
            return os.path.dirname(root)
    return target_root

def main():
    """Command-line entry point: locate scenarios and run every (scenario, policy, replicate).

    Scenario discovery (first match wins):
      1. ``--scenario_dir`` (a single scenario);
      2. ``--scenarios_root`` (searched recursively via ``discover_scenarios``);
      3. ``scenarios_all22.zip`` -- from ``--zip``, next to this script, or in the
         current directory -- extracted under ``<script dir>/unzipped``;
      4. ``<script dir>/scenarios``.

    Each run writes into ``<out_root>/<scenario>/<policy>/<NNN>``. Exits with
    status 1 when no scenario or no known policy is available.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--scenarios_root", type=str, default=None, help="Root directory containing scenario folders")
    ap.add_argument("--scenario_dir", type=str, default=None, help="Run a single scenario directory")
    ap.add_argument("--zip", type=str, default=None, help="Path to scenarios_all22.zip (optional; auto-detected if omitted)")
    ap.add_argument("--out_root", type=str, default=None, help="Root directory to write run manifests (default: ./runs next to this script)")
    ap.add_argument("--policies", type=str, default="baseline,myopic,slackaware", help="Comma-separated list")
    ap.add_argument("--replicates", type=int, default=30)
    ap.add_argument("--horizon_s", type=float, default=10800.0, help="Simulation horizon in seconds (default 3h)")
    # parse_known_args tolerates foreign argv entries (e.g. arguments injected when
    # run inside a notebook kernel). It still honours --help and rejects malformed
    # values for the options above. The former "except SystemExit: parse_args([])"
    # fallback turned --help or a typo into a full run with default settings.
    args, unknown = ap.parse_known_args()
    if unknown:
        print(f"Ignoring unrecognized arguments: {' '.join(unknown)}", file=sys.stderr)

    try:
        script_dir = os.path.dirname(os.path.abspath(__file__))
    except NameError:  # __file__ is undefined in interactive sessions
        script_dir = os.getcwd()

    if not args.out_root:
        args.out_root = os.path.join(script_dir, "runs")

    if args.scenario_dir:
        scenarios = [args.scenario_dir]
    elif args.scenarios_root:
        scenarios = discover_scenarios(args.scenarios_root)
    else:
        if args.zip and not os.path.isfile(args.zip):
            print(f"--zip file not found, falling back to auto-detection: {args.zip}", file=sys.stderr)
        zip_candidates = [
            path
            for path in (
                args.zip,
                os.path.join(script_dir, "scenarios_all22.zip"),
                os.path.join(os.getcwd(), "scenarios_all22.zip"),
            )
            if path and os.path.isfile(path)
        ]
        if zip_candidates:
            scen_root = unzip_if_needed(zip_candidates[0], os.path.join(script_dir, "unzipped"))
            scenarios = discover_scenarios(scen_root)
        else:
            fallback_root = os.path.join(script_dir, "scenarios")
            scenarios = discover_scenarios(fallback_root) if os.path.isdir(fallback_root) else []

    if not scenarios:
        print("No scenarios found. Provide --scenario_dir, --scenarios_root, or place scenarios_all22.zip next to sim_runner.py.", file=sys.stderr)
        sys.exit(1)

    requested = [p.strip().lower() for p in (args.policies or '').split(',') if p.strip()]
    if not requested:
        requested = ['baseline', 'myopic', 'slackaware']
    # Validate once up front (and drop duplicates, which would overwrite each
    # other's run dirs) so the planned-run count matches what actually runs.
    policies = []
    for pol in requested:
        if pol not in POLICY_REGISTRY:
            print(f"Unknown policy: {pol}", file=sys.stderr)
        elif pol not in policies:
            policies.append(pol)
    if not policies:
        print("No known policies to run.", file=sys.stderr)
        sys.exit(1)

    replicates = max(0, args.replicates)
    total_runs = len(scenarios) * len(policies) * replicates
    log_print(f"Total scenarios: {len(scenarios)} | Policies: {len(policies)} | Replicates: {replicates} | Planned runs: {total_runs}")
    prog = tqdm(total=total_runs, desc="All runs", leave=True) if _HAS_TQDM else None

    try:
        for scen in scenarios:
            scen_id = os.path.basename(scen.rstrip(os.sep))
            log_print(f"Starting scenario: {scen_id}")
            for pol in policies:
                log_print(f"  Policy: {pol}")
                for r in range(replicates):
                    run_dir = os.path.join(args.out_root, scen_id, pol, f"{r:03d}")
                    ensure_dir(run_dir)
                    if not _HAS_TQDM:
                        log_print(f"    Replicate {r+1}/{replicates} ({scen_id} | {pol})")
                    t0 = time.time()
                    sim = Simulator(scen_dir=scen, policy_name=pol, replicate_id=r, out_dir=run_dir, horizon_s=args.horizon_s)
                    sim.run()
                    if prog:
                        prog.update(1)
                    else:
                        dt_s = time.time() - t0
                        log_print(f"    Done replicate {r+1}/{replicates} in {dt_s:.1f}s ({scen_id} | {pol})")
            log_print(f"Finished scenario: {scen_id}")
    finally:
        # Close the progress bar even if a run raises.
        if prog:
            prog.close()

if __name__ == "__main__":
    # Start the CLI only when executed as a script; importing the module runs nothing.
    main()
